{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b926628-3855-4629-91b4-07f1577ea30e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle data path\n",
    "from pathlib import Path\n",
    "\n",
    "# Read and display data from Physionet\n",
    "import wfdb\n",
    "import pprint\n",
    "import collections\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Data manipulation and plotting\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.signal import resample\n",
    "from scipy.signal import butter,filtfilt\n",
    "\n",
    "# Divide data into train and test set and save to HDF5\n",
    "import h5py\n",
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import preprocessing\n",
    "\n",
    "# Over and undersampling\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "from imblearn.under_sampling import RandomUnderSampler\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e26b4118-f011-40c3-b6c5-c3cdbd431e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = Path('./data/2D_BW')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "# directory that store original MIT-BIH data\n",
    "\n",
    "img_dir = data_dir / 'img'\n",
    "os.makedirs(img_dir, exist_ok=True)\n",
    "\n",
    "label_dir = data_dir / 'label'\n",
    "os.makedirs(label_dir, exist_ok=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e79c784-bcd7-4b0b-a4c6-10adfd75ce83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert training images to nparray\n",
    "# Convert png files in a directory into the destination nparray\n",
    "def pngConverter (destination, image_dir, if_red_dim):\n",
    "    for root, dirs, files in os.walk(image_dir):\n",
    "        for file in files:\n",
    "            if file.endswith('.png'):\n",
    "                file_path = os.path.join(root,file)\n",
    "                img = cv2.imread(file_path)\n",
    "            if if_red_dim:\n",
    "                img = np.mean(img, axis = 2)\n",
    "            destination.append(img)\n",
    "            \n",
    "train_record_list = [101, 106, 108, 109, 112, 114, 115, 116, 118, 119, 122, 124, 201, 203, 205, 207, 208, 209, 215, 220, 223, 230]\n",
    "test_record_list = [100, 103, 105, 111, 113, 117, 121, 123, 200, 202, 210, 212, 213,214, 219, 221, 222, 228, 231, 232, 233, 234]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2910a8c1-cc19-4cec-a8dd-c3f97e8349a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_image = []\n",
    "train_label = []\n",
    "test_image = []\n",
    "test_label = []\n",
    "train_or_test = True\n",
    "\n",
    "for record_number in tqdm(train_record_list, total = len(train_record_list)):\n",
    "    image_dir = img_dir / str(record_number)\n",
    "    pngConverter(train_image, image_dir, True)\n",
    "    label_str = str(record_number)+'.csv'\n",
    "    label_path = label_dir / label_str\n",
    "    train_label.append(np.genfromtxt(label_path,delimiter = ','))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9eff91b-9b73-4f50-aa28-0a204ef0cd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio Split Helper Function\n",
    "import math\n",
    "def train_val_split (data_set, label_set, val_ratio = 0.2, shuffle = True):\n",
    "\n",
    "    n = len(data_set)\n",
    "    label_count = [[] for _ in range(4)]\n",
    "\n",
    "    if shuffle:\n",
    "        index = np.random.permutation(n)\n",
    "    else:\n",
    "        index = np.arange(n)\n",
    "    \n",
    "    data_set =  [data_set[idx] for idx in index]\n",
    "    label_set = [label_set[idx] for idx in index]\n",
    "\n",
    "    for idx, label in enumerate(label_set):\n",
    "        if label != 4:\n",
    "            label_count[int(label)].append(idx)\n",
    "\n",
    "    train_idx = []\n",
    "    val_idx = []\n",
    "\n",
    "    for i in range(len(label_count)):\n",
    "        current_count = len(label_count[i])\n",
    "        print(\"current label is %d with count %d\" % (i, current_count))\n",
    "        split_idx = math.ceil(current_count * val_ratio)\n",
    "        train_idx = train_idx + label_count[i][split_idx:]\n",
    "        val_idx = val_idx + label_count[i][:split_idx]\n",
    "    \n",
    "    return [data_set[idx] for idx in train_idx], [label_set[idx] for idx in train_idx], [data_set[idx] for idx in val_idx], [label_set[idx] for idx in val_idx]\n",
    "\n",
    "\n",
    "train_label_flattened = []\n",
    "for i in range(len(train_label)):\n",
    "    for j in range(len(train_label[i])):\n",
    "        train_label_flattened.append(train_label[i][j])\n",
    "        \n",
    "exp_train_image, exp_train_label, exp_val_image, exp_val_label = train_val_split(train_image, train_label_flattened)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b18b2ff8-8958-42c8-a226-faea7eb185c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.transform import rotate\n",
    "from skimage.util import random_noise\n",
    "\n",
    "# the augmentation method that keeps data from all classes the same augmentation\n",
    "# ratio and number\n",
    "\n",
    "def data_aug_2 (train_image, train_label):\n",
    "    aug_img = []\n",
    "    aug_label = []\n",
    "    for i in range(len(train_image)):\n",
    "        if train_label[i] == 0:\n",
    "            continue\n",
    "        r1_img = rotate(train_image[i], 90)\n",
    "        aug_img.append(r1_img)\n",
    "        aug_label.append(train_label[i])\n",
    "        if train_label[i] == 1:\n",
    "            continue\n",
    "        r2_img = rotate(train_image[i], 180)\n",
    "        aug_img.append(r2_img)\n",
    "        aug_label.append(train_label[i])\n",
    "        r3_img = np.flipud(train_image[i])\n",
    "        aug_img.append(r3_img)\n",
    "        aug_label.append(train_label[i])\n",
    "        if train_label[i] == 2:\n",
    "            continue\n",
    "        r4_img = np.fliplr(train_image[i])\n",
    "        aug_img.append(r4_img)\n",
    "        aug_label.append(train_label[i])\n",
    "\n",
    "    return train_image + aug_img, train_label + aug_label\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "112e0062-a5c5-426e-b7d6-6bc5a1f91a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Initial Train set length is %d\" % (len(exp_train_image)))\n",
    "print(\"Validation set length is %d\" % (len(exp_val_image)))\n",
    "\n",
    "aug_img, aug_label = data_aug_2(exp_train_image, exp_train_label)\n",
    "\n",
    "# shuffle\n",
    "index = np.random.permutation(len(aug_img))\n",
    "aug_img =  [aug_img[idx] for idx in index]\n",
    "aug_label = [aug_label[idx] for idx in index]\n",
    "\n",
    "print(len(aug_img))\n",
    "print(len(aug_label))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c1cb642-176a-4a9e-a083-2b17e56f7ad5",
   "metadata": {},
   "source": [
    "# Prepare model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "730f8bcd-9bb0-4022-8df9-ae2dd67e5148",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from scipy.stats import truncnorm\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch\n",
    "import h5py\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "650dcc9c-0cbc-4619-baf8-d1d4cc3f100b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# setting device on GPU if available, else CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# a generator for batches of data\n",
    "# yields data (batchsize) and labels (batchsize)\n",
    "# if shuffle is True, it will load batches in a random order\n",
    "def DataBatch(data, label, batchsize, shuffle=True):\n",
    "    n = len(data)\n",
    "    if shuffle:\n",
    "        index = np.random.permutation(n)\n",
    "    else:\n",
    "        index = np.arange(n)\n",
    "    for i in range(int(np.ceil(n/batchsize))):\n",
    "        inds = index[i*batchsize : min(n,(i+1)*batchsize)]\n",
    "        yield [data[idx] for idx in inds], [label[idx] for idx in inds]\n",
    "\n",
    "# evaluate method for cnn\n",
    "def eval_cnn(model, test_data, test_label, minibatch = 100, n=4):\n",
    "    correct=0.\n",
    "    M = np.zeros((n,n))\n",
    "\n",
    "    for i, (data,label) in enumerate(DataBatch(test_data,test_label,minibatch,shuffle=False)):\n",
    "        data = Variable(torch.FloatTensor(np.asarray(data)))\n",
    "        data = data.unsqueeze(1)\n",
    "        data = data.to(torch.device(\"cuda\"))\n",
    "        labels = Variable(torch.LongTensor(np.asarray(label)))\n",
    "        labels = labels.to(torch.device(\"cuda\"))\n",
    "        prediction = model.forward(data)\n",
    "        with torch.no_grad():\n",
    "            numpy_pred = prediction.cpu().numpy()\n",
    "            batch_pred = np.argmax(numpy_pred, axis=1)\n",
    "            correct += np.sum(batch_pred==label)\n",
    "        \n",
    "            for j in range(len(label)):\n",
    "                M[int(label[j]),int(batch_pred[j])] += 1\n",
    "\n",
    "    for i in range(n):\n",
    "        M[i,:] /= np.sum(M[i,:])\n",
    "      \n",
    "    acc = correct/len(test_data)*100\n",
    "    print('Test accuracy is %f' % (acc))\n",
    "    return M, acc\n",
    "\n",
    "\n",
    "# helper function to initialize weight variable\n",
    "def weight_variable(shape):\n",
    "    initial = torch.Tensor(truncnorm.rvs(-1/0.01, 1/0.01, scale=0.01, size=shape))\n",
    "    return Parameter(initial, requires_grad=True)\n",
    "\n",
    "# helper function to initialize bias variable\n",
    "def bias_variable(shape):\n",
    "    initial = torch.Tensor(np.ones(shape)*0.1)\n",
    "    return Parameter(initial, requires_grad=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37da973c-68f1-40ca-a286-c2db6a1096b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_net(model, trainData, trainLabels, valData, valLabels, epochs=10, learnRate = 1e-4, batchSize=50, weights = None):\n",
    "        \n",
    "        if weights:\n",
    "            class_weights = torch.FloatTensor(weights).cuda()\n",
    "            criterion = nn.CrossEntropyLoss(weight=class_weights)\n",
    "        else:\n",
    "            criterion = nn.CrossEntropyLoss()\n",
    "        optimizer = optim.Adam(model.parameters(), lr = learnRate)\n",
    "\n",
    "        train_loss = []\n",
    "        val_loss = []\n",
    "        train_accuracy = []\n",
    "        test_accuracy = []\n",
    "        \n",
    "        \n",
    "        for epoch in tqdm(range(epochs)):\n",
    "            model.to(torch.device(\"cuda\"))\n",
    "            model.train()  # set network in training mode\n",
    "            epoch_train_loss = []\n",
    "            epoch_val_loss = []\n",
    "\n",
    "            for i, (data,labels) in enumerate(DataBatch(trainData, trainLabels, batchSize, shuffle=True)):\n",
    "                data = Variable(torch.FloatTensor(np.asarray(data)))\n",
    "                data = data.unsqueeze(1)\n",
    "                data = data.to(torch.device(\"cuda\"))\n",
    "                labels = Variable(torch.LongTensor(np.asarray(labels)))\n",
    "                labels = labels.to(torch.device(\"cuda\"))\n",
    "                \n",
    "                # Now train the model using the optimizer and the batch data\n",
    "                prediction = model.forward(data)\n",
    "                loss = criterion(prediction, labels)\n",
    "                epoch_train_loss.append(loss.item())\n",
    "                # print('Epoch %d Batch number %d loss: %f' % (epoch+1, i, np.mean(np.array(epoch_train_loss))))\n",
    "                model.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            model.to(torch.device(\"cuda\"))\n",
    "            model.eval()  # set network in evaluation mode\n",
    "            # validation loss\n",
    "            for i, (val_data, val_labels) in enumerate(DataBatch(valData, valLabels, batchSize, shuffle=False)):\n",
    "                val_data = Variable(torch.FloatTensor(np.asarray(val_data)))\n",
    "                val_data = val_data.unsqueeze(1)\n",
    "                val_data = val_data.to(torch.device(\"cuda\"))\n",
    "                val_labels = Variable(torch.LongTensor(np.asarray(val_labels)))\n",
    "                val_labels = val_labels.to(torch.device(\"cuda\"))\n",
    "                with torch.no_grad():\n",
    "                    prediction = model.forward(val_data)\n",
    "                    loss = criterion(prediction, val_labels)\n",
    "                    epoch_val_loss.append(loss.item())\n",
    "\n",
    "            epoch_mean_val_loss = np.mean(np.array(epoch_val_loss))\n",
    "            val_loss.append(epoch_mean_val_loss)\n",
    "                \n",
    "            epoch_mean_train_loss = np.mean(np.array(epoch_train_loss))\n",
    "            train_loss.append(epoch_mean_train_loss)\n",
    "            \n",
    "        print ('Epoch:%d train loss: %f val loss: %f'%(epoch+1, epoch_mean_train_loss, epoch_mean_val_loss))\n",
    "        \n",
    "        return train_loss, val_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89b2831-fd85-4911-93bb-a5c057f558d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "resnet18 = models.resnet18(pretrained=True)\n",
    "resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False)\n",
    "resnet18.fc = nn.Linear(512, 4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6e3ed5-80d6-4702-9bff-4d81dd39dbb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93cd0bed-b0b1-49dc-a4e5-0a628af85a88",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_loss, val_loss = train_net(resnet18, aug_img, aug_label, exp_val_image, exp_val_label, epochs=10, batchSize=128)\n",
    "torch.save(resnet18.state_dict(), Path('./model/exp_resnet18_pretrained_aug.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4c6a843-dbbb-411a-aee0-1ae3ff860dcc",
   "metadata": {},
   "source": [
    "# Test model\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7be8c6-20af-467a-8c19-9787298bceac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handle data path\n",
    "from pathlib import Path\n",
    "\n",
    "# Read and display data from Physionet\n",
    "import wfdb\n",
    "import pprint\n",
    "import collections\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Data manipulation and plotting\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.signal import resample\n",
    "from scipy.signal import butter,filtfilt\n",
    "\n",
    "# Divide data into train and test set and save to HDF5\n",
    "import h5py\n",
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import preprocessing\n",
    "\n",
    "# Over and undersampling\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "from imblearn.under_sampling import RandomUnderSampler\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from torch import nn\n",
    "from scipy.stats import truncnorm\n",
    "from torch.nn.parameter import Parameter\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8cda12-2889-4ad2-97ad-cbb4eb085d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = Path('./data/2D_BW')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "\n",
    "# directory that store original MIT-BIH data\n",
    "\n",
    "img_dir = data_dir / 'img'\n",
    "os.makedirs(img_dir, exist_ok=True)\n",
    "\n",
    "label_dir = data_dir / 'label'\n",
    "os.makedirs(label_dir, exist_ok=True)\n",
    "\n",
    "# Convert training images to nparray\n",
    "# Convert png files in a directory into the destination nparray\n",
    "def pngConverter (destination, image_dir, if_red_dim):\n",
    "    for root, dirs, files in os.walk(image_dir):\n",
    "        for file in files:\n",
    "            if file.endswith('.png'):\n",
    "                file_path = os.path.join(root,file)\n",
    "                img = cv2.imread(file_path)\n",
    "            if if_red_dim:\n",
    "                img = np.mean(img, axis = 2)\n",
    "            destination.append(img)\n",
    "            \n",
    "train_record_list = [101, 106, 108, 109, 112, 114, 115, 116, 118, 119, 122, 124, 201, 203, 205, 207, 208, 209, 215, 220, 223, 230]\n",
    "test_record_list = [100, 103, 105, 111, 113, 117, 121, 123, 200, 202, 210, 212, 213,214, 219, 221, 222, 228, 231, 232, 233, 234]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e0b662-8613-427a-9965-d6c343d6aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_label = []\n",
    "test_image = []\n",
    "\n",
    "for record_number in tqdm(test_record_list, total = len(test_record_list)):\n",
    "    image_dir = img_dir / str(record_number)\n",
    "    pngConverter(test_image, image_dir, True)\n",
    "    label_str = str(record_number)+'.csv'\n",
    "    label_path = label_dir / label_str\n",
    "    test_label.append(np.genfromtxt(label_path,delimiter = ','))\n",
    "\n",
    "test_label_flattened = []\n",
    "for i in range(len(test_label)):\n",
    "    for j in range(len(test_label[i])):\n",
    "        test_label_flattened.append(test_label[i][j])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20167990-62ab-4725-9883-d86e220ed3a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_arr = np.array(test_label_flattened)\n",
    "index = np.where(np.array(label_arr) == 4)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28b596c1-2121-44a0-81ea-80947e7f8e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_label_flattened = [i for j, i in enumerate(test_label_flattened) if j not in index]\n",
    "test_image = [i for j, i in enumerate(test_image) if j not in index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06b76f2-e8e3-4a16-b60f-ec0fa9352785",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.unique(test_label_flattened))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b7b77e1-1b60-4f3a-8c82-d7fc530e579d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(test_image), test_image[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8712f177-83e2-4f47-9834-cdeb208f6f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import plot_confusion_matrix\n",
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "import torchvision.models as models\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "\n",
    "def DataBatch(data, label, batchsize, shuffle=True):\n",
    "    n = len(data)\n",
    "    if shuffle:\n",
    "        index = np.random.permutation(n)\n",
    "    else:\n",
    "        index = np.arange(n)\n",
    "    for i in range(int(np.ceil(n/batchsize))):\n",
    "        inds = index[i*batchsize : min(n,(i+1)*batchsize)]\n",
    "        yield [data[idx] for idx in inds], [label[idx] for idx in inds]\n",
    "\n",
    "def test_model(model, test_data, test_label, batch_size=64, n=4):\n",
    "    \"\"\"\n",
    "    This function will run test of the model on the test dataset and return \n",
    "        - classification report string (for display purpose)\n",
    "        - dictionary of classification report (for query purpose)\n",
    "        - confusion matrix\n",
    "    \"\"\"\n",
    "    \n",
    "    predictions = []\n",
    "    labels = []\n",
    "    model.to(device)\n",
    "    \n",
    "    for i, (data,label) in enumerate(DataBatch(test_data,test_label,batch_size,shuffle=False)):\n",
    "        data = Variable(torch.FloatTensor(np.asarray(data)))\n",
    "        data = data.unsqueeze(1)\n",
    "        data = data.to(torch.device(\"cuda\"))\n",
    "        prediction = model.forward(data)\n",
    "        with torch.no_grad():\n",
    "            data, label = data.to(device), label\n",
    "            predictions += list(np.argmax(model(data).cpu().numpy(), axis=1))\n",
    "            labels += list(label)\n",
    "            \n",
    "    predictions = np.array(predictions)\n",
    "    labels = np.array(labels)\n",
    "        \n",
    "    target_names = ['N', 'S', 'V', 'F']\n",
    "    report = classification_report(labels, predictions, target_names=target_names, digits=3)\n",
    "    report_dict = classification_report(labels, predictions, target_names=target_names, output_dict=True)\n",
    "    c_matrix = confusion_matrix(labels, predictions)\n",
    "    return report, report_dict, c_matrix\n",
    "\n",
    "resnet18 = models.resnet18(pretrained=False)\n",
    "resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False)\n",
    "resnet18.fc = nn.Linear(512, 4)\n",
    "resnet18.load_state_dict(torch.load(Path('./model/exp_resnet18_pretrained_aug.pt')))\n",
    "resnet18.eval()\n",
    "\n",
    "report, report_dict, c_matrix = test_model(resnet18, test_image, test_label_flattened)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "136fbc6a-cd55-4d78-b789-d1e1aef94c4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Training result:\\n', report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beae2e2c-ecd5-4704-9ea8-fd27d96687f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "## display confusion matrix\n",
    "display_labels = ['N', 'S', 'V', 'F']\n",
    "\n",
    "def plot_confusion_matrix(cm, classes,\n",
    "                          normalize=False,\n",
    "                          title='Confusion matrix',\n",
    "                          cmap=plt.cm.Blues):\n",
    "    \"\"\"\n",
    "    This function prints and plots the confusion matrix.\n",
    "    Normalization can be applied by setting `normalize=True`.\n",
    "    \"\"\"\n",
    "    if normalize:\n",
    "        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "        print(\"Normalized confusion matrix\")\n",
    "    else:\n",
    "        print('Confusion matrix, without normalization')\n",
    "\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(title)\n",
    "    plt.colorbar()\n",
    "    tick_marks = np.arange(len(classes))\n",
    "    plt.xticks(tick_marks, classes, rotation=45)\n",
    "    plt.yticks(tick_marks, classes)\n",
    "\n",
    "    fmt = '.2f' if normalize else 'd'\n",
    "    thresh = cm.max() / 2.\n",
    "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "        plt.text(j, i, format(cm[i, j], fmt),\n",
    "                 horizontalalignment=\"center\",\n",
    "                 color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.ylabel('True label')\n",
    "    plt.xlabel('Predicted label')\n",
    "    plt.show()\n",
    "    plt.clf()\n",
    "    \n",
    "plot_confusion_matrix(c_matrix, display_labels ,\n",
    "                      title='Normalzied Confusion Matrix', normalize=True, cmap='Greys')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
